{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the model \"create_model_from_pretrained\"\n",
    "By default, the model preprocessor uses 448 x 448 as the input size. To specify a different image size (e.g. 336 x 336), use the **force_img_size** argument.\n",
    "\n",
    "You can specify a cuda device by using the **device** argument, or manually move the model to a device later using **model.to(device)**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_cfg = 'conch_ViT-B-16'\n",
    "checkpoint_path = './checkpoints/CONCH/pytorch_model.bin'\n",
    "model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path)\n",
    "# model, preprocess = create_model_from_pretrained(model_cfg, checkpoint_path, force_img_size=224, device='cuda:2')\n",
    "_ = model.eval()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embed images \n",
    "The **.encode_image()** method encodes a batch of images into a batch of image embeddings. Note that this function applies the contrastive learning projection head to the image and performs l2-normalization before returning the embedding, which is used for computing the similarity scores such as between images and texts. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "image = Image.open('../docs/roi1.jpg')\n",
    "image = preprocess(image).unsqueeze(0)\n",
    "print(image.shape)\n",
    "\n",
    "with torch.inference_mode():\n",
    "    image_embs = model.encode_image(image)\n",
    "    \n",
    "print(image_embs.shape)\n",
    "print(image_embs.norm(dim=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For image-only tasks, it is common to directly use the representation before the projection head and l2-normalization. This is done by setting **proj_contrast=False** and **normalize=False**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.inference_mode():\n",
    "    image_embs = model.encode_image(image, proj_contrast=False, normalize=False)\n",
    "\n",
    "print(image_embs.shape)\n",
    "print(image_embs.norm(dim=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embed texts\n",
    "The **.encode_text()** method encodes a batch of texts into a batch of l2-normalized text embeddings used for computing the similarity scores such as between images and texts. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = [\"H&E image of lung adenocarcinoma\",\n",
    "         \"photomicrograph of a lung squamous cell carcinoma, H&E stain\"]\n",
    "tokenizer = get_tokenizer() # load tokenizer\n",
    "text_tokens = tokenize(texts=texts, tokenizer=tokenizer) # tokenize the text\n",
    "text_embs = model.encode_text(text_tokens)\n",
    "print(text_embs.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
